#!/usr/bin/env python3
from params_proto import ParamsProto
from params_proto import Proto


class Args(ParamsProto):
    """Default hyperparameters for a run, declared as a ParamsProto namespace.

    Every attribute below is a plain class-level default. The option lists in
    the comments are the values the original authors documented for each
    field; they are not validated here.
    """

    # --- Algorithm selection -------------------------------------------------
    # Options: lops-il, lops-lambda, lops-aps-ase, lops-aps, mamba, pg-gae,
    # aggrevated, rpi.
    # lops-lambda additionally needs "lmd" and "maxplus_switch_rl_round" set.
    algorithm = "lops-aps-ase"

    # Options: value_std, value_max_gap, next_state_mean, euclidean_mean,
    # wasserstein_mean.
    ase_uncertainty = "next_state_mean"
    available_gpu = [0, 2, 3]
    num_train_steps = 100
    gamma = 1.0
    # 0: max-aggregate IL, 1: RL.
    # lops-il: lmd=0. lops-lambda: applies while
    # round < maxplus_switch_rl_round; otherwise lmd=1.
    lmd = 0.9

    # DEPRECATED
    load_expert_step = None

    # Example: [('sac', path-to-model), ('ppo', path-to-model), ...]
    experts_info = []

    max_grad_norm = 1.0

    # --- Main training loop --------------------------------------------------
    num_rollouts = 64  # previously 8
    num_eval_episodes = 32  # previously 8
    # num_updates = 32  # 100   (disabled)
    num_epochs = 4
    seed = 0
    use_ase_sigma_coef = False
    ase_sigma = 10.0
    use_ase_sigma_ratio = True
    ase_sigma_ratio = 0.1
    ase_sigma_coef = 1.0
    batch_size = 128

    # --- Expert value functions ----------------------------------------------
    reset_expert_vfn = True
    expert_vfn_gain = 1.0
    num_expert_vfns = 5

    # --- Pre-training --------------------------------------------------------
    pret_num_rollouts = 16  # 16
    # pret_num_updates = 100   (disabled)
    pret_num_epochs = 30
    pret_num_val_iterations = 4

    state_pred_num_epochs = 8

    # --- Buffers -------------------------------------------------------------
    learner_buffer_size = 2048
    expert_buffer_size = 8192

    use_ppo_loss = False
    use_expert_obsnormalizer = True

    expert_tgtval = "monte-carlo"

    # When True, stddev and variance are computed from the predicted means
    # and the predicted stddev is ignored.
    std_from_means = False

    deterministic_experts = False

    # --- Paths (declared with Proto(env=...), keyed by environment variable) --
    saved_states_dir = Proto(env="SAVED_STATES_DIR")
    experts_dir = Proto(env="EXPERTS_DIR")

    # --- Environment ---------------------------------------------------------
    env_name = "DartCartPole-v1"  # alternative: 'dmc:Cheetah-run-v1'
    max_episode_len = 300

    # Options: 'none', 'rollin', 'all'.
    use_riro_for_learner_pi = "none"

    # --- APS / MaxPlus value selection ---------------------------------------
    # Each of the four fields below takes one of: 'ucb', 'lcb', 'mean', "NONE".
    aps_expert = "ucb"
    aps_learner = "lcb"
    maxplus_expert = "mean"
    maxplus_learner = "mean"
    # A value <= total training steps; may also be used by lops.
    maxplus_switch_rl_round = 999_999
    # Group name for saving results.
    group = "default"
    # Value measuring the difference between upper and lower in MaxValFnPlus.
    state_in_distribution = 999_999
    # MaxValFnPlus decay rate for exploration of experts; 0 means no decay
    # factor.
    explore_decay_rate = 0